from __future__ import annotations

import datetime
import itertools
import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import ClassVar

import pywikibot.time  # type: ignore[import-untyped]
from pywikibot import pagegenerators  # type: ignore[import-untyped]
from pywikibot import textlib  # type: ignore[import-untyped]

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.mediawiki.family import family_class_dispatch
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger


logger = setup_logger()

# Keep a reference to the TemporaryDirectory so pywikibot's config directory on disk is not
# removed by garbage collection while the connector is still using it.
_pywikibot_base_dir = tempfile.TemporaryDirectory()
pywikibot.config.base_dir = _pywikibot_base_dir.name


def pywikibot_timestamp_to_utc_datetime(
    timestamp: pywikibot.time.Timestamp,
) -> datetime.datetime:
    """Convert a pywikibot timestamp to a datetime object in UTC.

    Args:
        timestamp: The pywikibot timestamp to convert.

    Returns:
        A datetime object in UTC.
    """
    return timestamp.astimezone(tz=datetime.timezone.utc)


def get_doc_from_page(
    page: pywikibot.Page, site: pywikibot.Site | None, source_type: DocumentSource
) -> Document:
    """Generate Onyx Document from a MediaWiki page object.

    Args:
        page: Page from a MediaWiki site.
        site: MediaWiki site (used to parse the page's sections according to the site's configuration, if available).
        source_type: Source of the document.

    Returns:
        Generated document.
    """
    page_text = page.text
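    # Split the raw wikitext into a lead (header) plus its individual sections.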
    sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)

    sections = [
        TextSection(
            link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
            text=section.title + section.content,
        )
        for section in sections_extracted.sections
    ]
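    # The lead text (everything before the first heading) becomes its own section,
    # linked to the page itself rather than to a heading anchor.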
    sections.append(
        TextSection(
            link=page.full_url(),
            text=sections_extracted.header,
        )
    )

    return Document(
        source=source_type,
        title=page.title(),
        doc_updated_at=pywikibot_timestamp_to_utc_datetime(
            page.latest_revision.timestamp
        ),
        sections=cast(list[TextSection | ImageSection], sections),
        semantic_identifier=page.title(),
        metadata={"categories": [category.title() for category in page.categories()]},
        id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",
    )


class MediaWikiConnector(LoadConnector, PollConnector):
    """A connector for MediaWiki wikis.

    Args:
        hostname: The hostname of the wiki.
        categories: The categories to include in the index.
        pages: The pages to include in the index.
        recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
        language_code: The language code of the wiki.
        batch_size: The batch size for loading documents.

    Raises:
        ValueError: If `recurse_depth` is not an integer greater than or equal to -1.
    """

    document_source_type: ClassVar[DocumentSource] = DocumentSource.MEDIAWIKI
    """DocumentSource type for all documents generated by instances of this class. Can be overridden for connectors
    tailored for specific sites."""

    def __init__(
        self,
        hostname: str,
        categories: list[str],
        pages: list[str],
        recurse_depth: int,
        language_code: str = "en",
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        if recurse_depth < -1:
            raise ValueError(
                f"recurse_depth must be an integer greater than or equal to -1. Got {recurse_depth} instead."
            )
        # -1 means infinite recursion, which `pywikibot` will only do with `True`
        self.recurse_depth: bool | int = True if recurse_depth == -1 else recurse_depth

        self.batch_size = batch_size

        # Family short names may only contain ASCII letters and digits, so a fixed
        # name is used for the generated pywikibot Family rather than the raw hostname.
        self.family = family_class_dispatch(hostname, "WikipediaConnector")()
        self.site = pywikibot.Site(fam=self.family, code=language_code)
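        # Normalize category names: use underscores for spaces and ensure the
        # "Category:" namespace prefix is present.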
        self.categories = [
            pywikibot.Category(
                self.site,
                (
                    f"{category.replace(' ', '_')}"
                    if category.startswith("Category:")
                    else f"Category:{category.replace(' ', '_')}"
                ),
            )
            for category in categories
        ]

        self.pages = []
        for page in pages:
            if not page:
                continue
            self.pages.append(pywikibot.Page(self.site, page))

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials for a MediaWiki site.

        Note:
            For most read-only operations, MediaWiki API credentials are not necessary.
            This method can be overridden in the event that a particular MediaWiki site
            requires credentials.
        """
        return None

    def _get_doc_batch(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Generator[list[Document], None, None]:
        """Request batches of pages from a MediaWiki site.

        Args:
            start: The beginning of the time period of pages to request.
            end: The end of the time period of pages to request.

        Yields:
            Lists of Documents containing each parsed page in a batch.
        """
        doc_batch: list[Document] = []

        # Pywikibot can handle batching for us, including only loading page contents when we finally request them.
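        # `EdittimeFilterPageGenerator` keeps only pages whose last edit falls within the
        # requested [start, end] window (when bounds are provided).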
        category_pages = [
            pagegenerators.PreloadingGenerator(
                pagegenerators.EdittimeFilterPageGenerator(
                    pagegenerators.CategorizedPageGenerator(
                        category, recurse=self.recurse_depth
                    ),
                    last_edit_start=(
                        datetime.datetime.fromtimestamp(start) if start else None
                    ),
                    last_edit_end=datetime.datetime.fromtimestamp(end) if end else None,
                ),
                groupsize=self.batch_size,
            )
            for category in self.categories
        ]

        # Since we can specify both individual pages and categories, we need to iterate over all of them.
        all_pages: Iterator[pywikibot.Page] = itertools.chain(
            self.pages, *category_pages
        )
        for page in all_pages:
            logger.info(
                f"MediaWikiConnector: title='{page.title()}' url={page.full_url()}"
            )
            doc_batch.append(
                get_doc_from_page(page, self.site, self.document_source_type)
            )
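            # Emit a full batch as soon as it reaches the configured size; any remainder
            # is yielded after the loop.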
            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []
        if doc_batch:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from the source.

        Returns:
            A generator of documents.
        """
        return self.poll_source(None, None)

    def poll_source(
        self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
    ) -> GenerateDocumentsOutput:
        """Poll the source for new documents.

        Args:
            start: The start of the time range to poll.
            end: The end of the time range to poll.

        Returns:
            A generator of documents.
        """
        return self._get_doc_batch(start, end)


if __name__ == "__main__":
    HOSTNAME = "fallout.fandom.com"
    test_connector = MediaWikiConnector(
        hostname=HOSTNAME,
        categories=["Fallout:_New_Vegas_factions"],
        pages=["Fallout: New Vegas"],
        recurse_depth=1,
    )

    all_doc_batches = list(test_connector.load_from_state())
    print("All doc batches", all_doc_batches)
    current = datetime.datetime.now().timestamp()
    thirty_days_ago = current - 30 * 24 * 60 * 60  # 30 days

    latest_doc_batches = list(test_connector.poll_source(thirty_days_ago, current))

    print("Latest doc batches", latest_doc_batches)
